package xin.marcher.wind.migrate.service.impl;

import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.date.DatePattern;
import cn.hutool.core.date.LocalDateTimeUtil;
import cn.hutool.core.util.StrUtil;
import com.baomidou.mybatisplus.core.toolkit.CollectionUtils;
import lombok.extern.slf4j.Slf4j;
import org.apache.ibatis.session.SqlSession;
import org.springframework.stereotype.Service;
import xin.marcher.wind.migrate.config.ScrollDataSourceConfig;
import xin.marcher.wind.migrate.domain.BinLog;
import xin.marcher.wind.migrate.domain.RangeScroll;
import xin.marcher.wind.migrate.domain.entity.EtlProgressDO;
import xin.marcher.wind.migrate.domain.entity.EtlStatisticalDO;
import xin.marcher.wind.migrate.enums.BinlogType;
import xin.marcher.wind.migrate.enums.DBChannel;
import xin.marcher.wind.migrate.mapper.MigrateScrollMapper;
import xin.marcher.wind.migrate.migrate.MergeConfig;
import xin.marcher.wind.migrate.migrate.PreparedStatementUtil;
import xin.marcher.wind.migrate.migrate.ScrollProcessor;
import xin.marcher.wind.migrate.service.MigrateConfigService;
import xin.marcher.wind.migrate.service.MigrateService;
import xin.marcher.wind.migrate.util.MigrateUtil;

import javax.annotation.Resource;
import java.math.BigDecimal;
import java.math.RoundingMode;
import java.sql.Date;
import java.sql.*;
import java.util.*;

/**
 * 数据同步服务实现类
 */
@Service
@Slf4j
public class MigrateServiceImpl implements MigrateService {

    @Resource
    private MigrateScrollMapper migrateScrollMapper;

    @Resource
    private ScrollDataSourceConfig scrollDataSourceConfig;

    @Resource
    private ScrollProcessor scrollProcessor;

    @Resource
    MigrateConfigService migrateConfigService;

    @Override
    public boolean migrateBat(RangeScroll scroll, List<BinLog> binLogs) {
        log.info("开始执行 migrateBat 方法，tableName=" + scroll.getTableName() + ", 本次操作 " + binLogs.size() + " 条记录");
        if (!Objects.isNull(scroll) && CollUtil.isNotEmpty(binLogs)) {
            try {
                List<Map<String, Object>> insertMaps = new ArrayList<>();
                List<Map<String, Object>> deleteMaps = new ArrayList<>();
                for (BinLog binLog : binLogs) {
                    if (BinlogType.INSERT.getValue().equals(binLog.getOperateType())) {
                        // 新增操作单独拎出来做批量新增，不然执行效率太低
                        insertMaps.add(binLog.getDataMap());
                    } else if (BinlogType.UPDATE.getValue().equals(binLog.getOperateType())) {
                        //处理一下更新的null异常对象
                        binLog.setDataMap(MigrateUtil.updateNullValue(binLog.getDataMap()));
                        update(binLog.getDataMap(), scroll);
                    } else if (BinlogType.DELETE.getValue().equals(binLog.getOperateType())) {
                        deleteMaps.add(binLog.getDataMap());
                    }
                }
                // 批量新增
                if (CollUtil.isNotEmpty(insertMaps)) {
                    MigrateUtil.removeNullValue(insertMaps);
                    insertBat(insertMaps, scroll);
                }
                // 批量删除
                if (CollectionUtils.isNotEmpty(deleteMaps)) {
                    delete(deleteMaps, scroll);
                }
            } catch (Exception e) {
                log.error("migrateBat () tableName=" + scroll.getTableName(), e);
                return false;
            }
            return true;
        }
        return false;
    }

    @Override
    public List<String> getScrollAbleTables() {
        return migrateScrollMapper.getScrollAbleTables();
    }

    @Override
    public void compensateRangeScroll(Long id, String domain) {
        EtlProgressDO etlProgressDOInfo = migrateScrollMapper.queryEtlProgressById(id);
        RangeScroll rangeScroll = new RangeScroll();
        rangeScroll.setDomain(domain);
        rangeScroll.setDomainId(migrateConfigService.getDomainId(domain));
        rangeScroll.setStartScrollId(etlProgressDOInfo.getScrollId());
        rangeScroll.setTableName(etlProgressDOInfo.getLogicModel());
        rangeScroll.setStartTime(etlProgressDOInfo.getScrollTime());
        rangeScroll.setEndTime(etlProgressDOInfo.getScrollEndTime());
        rangeScroll.setCurTicketStage(etlProgressDOInfo.getCurTicketStage());
        rangeScroll.setTicket(etlProgressDOInfo.getTicket());
        rangeScroll.setRetryFlag(true);
        rangeScroll.setRetryTimes(etlProgressDOInfo.getRetryTimes() + 1);
        rangeScroll.setPageSize(etlProgressDOInfo.getFinishRecord());
        rangeScroll.setProgressType(etlProgressDOInfo.getProgressType());
        //补偿再次发起
        scrollProcessor.scroll(rangeScroll);
    }

    @Override
    public List<EtlProgressDO> getEtlProgresses(EtlProgressDO queryCondition) {
        try {
            if (null == queryCondition) {
                // 防止传个 null 过来造成 mybatis 处理出错
                queryCondition = new EtlProgressDO();
            }

            List<EtlProgressDO> progressList = migrateScrollMapper.queryEtlProgressList(queryCondition);
            if (CollectionUtils.isNotEmpty(progressList)) {
                for (EtlProgressDO etlProgressDO : progressList) {
                    Integer startTime = Integer.valueOf(LocalDateTimeUtil.format(etlProgressDO.getScrollTime(), DatePattern.PURE_DATE_PATTERN));
                    Integer endTime = Integer.valueOf(LocalDateTimeUtil.format(etlProgressDO.getScrollEndTime(), DatePattern.PURE_DATE_PATTERN));

                    EtlStatisticalDO etlStatisticalDO = new EtlStatisticalDO();
                    etlStatisticalDO.setDomain(etlProgressDO.getDomain());
                    etlStatisticalDO.setLogicModel(etlProgressDO.getLogicModel());
                    etlStatisticalDO.setStartTime(startTime);
                    etlStatisticalDO.setEndTime(endTime);
                    // 获取已同步的数据（通过CountCacheTask分天统计计算的数据）
                    BigDecimal statisticalCount = migrateScrollMapper.getStatisticalCount(etlStatisticalDO);
                    // 如果存在已经同步的数据数量，则计算同步进度，否则设置同步进度为0%
                    if (null != statisticalCount && null != etlProgressDO.getFinishRecord()) {
                        BigDecimal progressScale = new BigDecimal(etlProgressDO.getFinishRecord()).divide(statisticalCount, 2, BigDecimal.ROUND_HALF_UP);
                        // 因为前端展示的进度条需要的是百分比的数字，所以这里把结果乘以100
                        etlProgressDO.setProgressScale(progressScale.multiply(new BigDecimal(100)));
                    } else {
                        etlProgressDO.setProgressScale(BigDecimal.ZERO);
                    }
                }
            }
            return progressList;
        } catch (Exception e) {
            log.error("getEtlProgresses方法执行出错", e);
            return new ArrayList<>();
        }
    }

    @Override
    public String queryMinScrollId(RangeScroll rangeScroll) {
        // 验证必填参数
        if (StrUtil.isNotBlank(rangeScroll.getTableName())) {
            SqlSession session = null;
            PreparedStatement pst = null;
            try {
                rangeScroll.setScrollName(MergeConfig.getSingleKey(rangeScroll.getTableName()));

                String sql = " select " + rangeScroll.getScrollName() + " from " + rangeScroll.getTableName() + " where create_time >= ?" +
                        " order by create_time asc, " + rangeScroll.getScrollName() + " asc LIMIT 1";
                // 获取指定的数据源
                session = scrollDataSourceConfig.getSqlSession(rangeScroll.getDomain(), 1);
                pst = session.getConnection().prepareStatement(sql);
                pst.setDate(1, Date.valueOf(rangeScroll.getStartTime().toLocalDate()));
                ResultSet result = pst.executeQuery();
                while (result.next()) {
                    return String.valueOf(Long.parseLong(result.getString(1)) - 1);
                }
                return null;
            } catch (Exception e) {
                log.error("queryInfoList方法执行出错", e);
                return "0";
            } finally {
                closeSqlSession(session, pst, rangeScroll.getDomain(), 1);
            }
        }
        return "0";
    }

    /**
     * 负责分页滚动数据
     *
     * @param rangeScroll 查询条件
     * @return
     */
    @Override
    @SuppressWarnings({"unchecked"})
    public List<Map<String, Object>> queryInfoList(RangeScroll rangeScroll) {
        if (StrUtil.isNotBlank(rangeScroll.getTableName()) && StrUtil.isNotBlank(rangeScroll.getStartScrollId())) {
            SqlSession session = null;
            PreparedStatement pst = null;
            try {
                String sql = "select * from " + rangeScroll.getTableName() + " where " + rangeScroll.getScrollName() + " > ? " +
                        " order by  " + rangeScroll.getScrollName() + " asc LIMIT " + rangeScroll.getPageSize();
                // 获取指定的数据连接
                session = scrollDataSourceConfig.getSqlSession(rangeScroll.getDomain(), 1);
                pst = session.getConnection().prepareStatement(sql);

                pst.setString(1, rangeScroll.getStartScrollId());
                ResultSet resultSet = pst.executeQuery();

                return converter(resultSet);
            } catch (Exception e) {
                log.error("queryInfoList 方法执行出错", e);
                return new ArrayList<>();
            } finally {
                closeSqlSession(session, pst, rangeScroll.getDomain(), 1);
            }
        }
        return new ArrayList<>();
    }

    /**
     * 批量查询数据信息
     *
     * @param scroll      数据对象
     * @param identifiers 唯一标识List
     * @param dbChannel   指向具体的BD库
     * @return
     */
    @Override
    @SuppressWarnings({"unchecked"})
    public List<Map<String, Object>> findByIdentifiers(RangeScroll scroll, List<String> identifiers, String dbChannel) {
        if (!Objects.isNull(scroll) && CollUtil.isNotEmpty(identifiers)) {
            SqlSession session = null;
            PreparedStatement pst = null;
            Integer dataSourceType = 2;
            try {
                if (DBChannel.CHANNEL_1.getValue().equals(dbChannel)) {
                    dataSourceType = 1;
                }
                session = scrollDataSourceConfig.getSqlSession(scroll.getDomain(), dataSourceType);

                if (null != session) {

                    StringBuffer sql = new StringBuffer();
                    sql.append("select * from " + scroll.getTargetTableName() + " where " + scroll.getScrollName() + " in (");
                    for (String id : identifiers) {
                        sql.append("?,");
                    }
                    String sqlStr = sql.substring(0, sql.length() - 1) + ")";

                    pst = session.getConnection().prepareStatement(sqlStr);
                    for (int i = 1; i <= identifiers.size(); i++) {
                        pst.setString(i, identifiers.get(i - 1));
                    }
                    ResultSet resultSet = pst.executeQuery();
                    return converter(resultSet);
                }
            } catch (Exception e) {
                log.error("findByIdentifiers方法执行出错", e);
                return new ArrayList<>();
            } finally {
                closeSqlSession(session, pst, scroll.getDomain(), dataSourceType);
            }
        }
        return new ArrayList<>();
    }

    /**
     * 批量新增
     *
     * @param insertMaps
     * @param scroll
     */
    private void insertBat(List<Map<String, Object>> insertMaps, RangeScroll scroll) throws Exception {
        SqlSession session = null;
        PreparedStatement pst = null;
        try {
            StringBuffer insertSql = new StringBuffer();

            StringBuffer sql = new StringBuffer();
            sql.append("insert into ").append(scroll.getTargetTableName()).append(" (");
            Map<String, Object> insertMap = insertMaps.get(0);
            // 将 key 作为写入的字段拼接, 值为当前需要写入的具体字段名称
            insertMap.keySet().forEach(key -> sql.append(key).append(","));
            String sqlStr = sql.substring(0, sql.length() - 1) + ") values (";
            insertSql.append(sqlStr);
            // 批量封装新增的数据
            for (Map<String, Object> inset : insertMaps) {
                StringBuffer insertValue = new StringBuffer();
                inset.keySet().forEach(key -> insertValue.append("?,"));
                String insertStr = insertValue.substring(0, insertValue.length() - 1) + "),(";
                insertSql.append(insertStr);
            }

            String insertBatSql = insertSql.substring(0, insertSql.length() - 2);

            // 开始加载写入的数据源信息
            session = scrollDataSourceConfig.getSqlSession(scroll.getDomain(), 2);
            pst = session.getConnection().prepareStatement(insertBatSql);
            int pos = 1;
            for (Map<String, Object> insert : insertMaps) {
                for (String key : insert.keySet()) {
                    PreparedStatementUtil.buildPerParedStatement(pst, pos, insert.get(key));
                    pos++;
                }
            }

            // 执行 sql 语句
            pst.addBatch();
            pst.executeBatch();
        } catch (Exception e) {
            log.error("批量保存数据异常", e);
        } finally {
            closeSqlSession(session, pst, scroll.getDomain(), 2);
        }
    }

    /**
     * 修改数据
     *
     * @param updateMap
     * @param scroll
     * @throws Exception
     */
    private void update(Map<String, Object> updateMap, RangeScroll scroll) throws Exception {
        SqlSession session = null;
        PreparedStatement pst = null;

        try {
            StringBuffer updateSql = new StringBuffer();
            updateSql.append("update ").append(scroll.getTargetTableName()).append(" set ");
            // 拼凑更新语句
            for (String key : updateMap.keySet()) {
                updateSql.append(key).append("=?,");
            }
            String updateStr = updateSql.substring(0, updateSql.length() - 1);

            updateStr = updateStr + " where " + scroll.getScrollName() + "=" + updateMap.get(scroll.getScrollName());

            // 开始加载写入的数据源信息
            session = scrollDataSourceConfig.getSqlSession(scroll.getDomain(), 2);
            pst = session.getConnection().prepareStatement(updateStr);

            int pos = 1;
            for (String key : updateMap.keySet()) {
                PreparedStatementUtil.buildPerParedStatement(pst, pos, updateMap.get(key));
                pos++;
            }
            // 执行sql语句
            pst.executeUpdate();
        } catch (Exception e) {
            log.error("批量更新数据异常", e);
        } finally {
            closeSqlSession(session, pst, scroll.getDomain(), 2);
        }
    }

    /**
     * 批量删除数据
     *
     * @param deleteMaps
     * @param scroll
     */
    private void delete(List<Map<String, Object>> deleteMaps, RangeScroll scroll) throws Exception {
        StringBuffer deleteSql = new StringBuffer();
        deleteSql.append("delete from " + scroll.getTargetTableName() + " where " + scroll.getScrollName() + " in (");
        // 拼凑更新语句
        for (Map<String, Object> deleteMap : deleteMaps) {
            deleteSql.append(deleteMap.get(scroll.getScrollName() + ","));
        }
        String deleteStr = deleteSql.substring(0, deleteSql.length() - 1);
        deleteStr = deleteStr + ")";
        // 执行sql语句
        executeUpdate(scroll, deleteStr);
    }

    /**
     * 对集合查询的封装出一个小方法
     *
     * @param resultSet
     * @return
     * @throws Exception
     */
    private List<Map<String, Object>> converter(ResultSet resultSet) throws Exception {
        List<Map<String, Object>> list = new ArrayList<>();

        ResultSetMetaData md = resultSet.getMetaData();
        int columnCount = md.getColumnCount();
        while (resultSet.next()) {
            Map<String, Object> rowData = new HashMap<>();
            for (int i = 1; i <= columnCount; i++) {
                Object value = resultSet.getObject(i);
                rowData.put(md.getColumnName(i), value);
            }
            list.add(rowData);
        }
        return list;
    }

    /**
     * 执行sql语句
     *
     * @param scroll
     * @param executeSql
     */
    private void executeUpdate(RangeScroll scroll, String executeSql) throws Exception {
        SqlSession session = null;
        PreparedStatement pst = null;
        try {
            // 开始加载写入的数据源信息
            session = scrollDataSourceConfig.getSqlSession(scroll.getDomain(), 2);
            pst = session.getConnection().prepareStatement(executeSql);
            // 执行语句写入
            pst.executeUpdate(executeSql);
        } catch (Exception e) {
            log.error("sql执行失败", e);
        } finally {
            closeSqlSession(session, pst, scroll.getDomain(), 2);
        }
    }

    /**
     * 关闭连接
     *
     * @param session
     * @param pst
     * @param domain
     * @param dataSourceType
     */
    private void closeSqlSession(SqlSession session, PreparedStatement pst, String domain, Integer dataSourceType) {
        if (pst != null) {
            try {
                pst.close();
            } catch (SQLException e) {
                log.error("关闭 SqlSession 执行失败", e);
            }
        }
        scrollDataSourceConfig.closeSqlSession(session, domain, dataSourceType);
    }

    private Long getDomainId(String domain) {
        return migrateConfigService.getDomainId(domain);
    }
}
